from allennlp.predictors import Predictor

from .attacker import BlackBoxAttacker
from .candidates import WordNetCandidate, IBPCandidates
from config import Config
from tools.utils import softmax


class PWWSAttacker(BlackBoxAttacker):
    """Probability Weighted Word Saliency (PWWS) style black-box attacker.

    For every word position that has at least one synonym candidate, the
    attacker picks the synonym whose substitution lowers the model's
    ``gold_prob`` the most.  That drop is weighted by a softmax-normalised
    "unknown-token saliency": the drop in ``gold_prob`` when the word is
    replaced by ``@@UNKNOWN@@``.  The highest-scoring positions are returned
    as ``(position, synonym)`` substitution pairs.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super(PWWSAttacker, self).__init__(cf, predictor)

        # Synonym generator restricted to the POS tags exposed by the base class.
        self.synonym_candidate = WordNetCandidate(self.supported_postag)
        # Token substituted at a position to measure how much that word matters.
        self.unk_candidate = '@@UNKNOWN@@'

    def get_victim_substitute_pair(self, text):
        """Rank word positions by PWWS score and return the top substitutions.

        Args:
            text: dict with at least a ``'sentence'`` key (token sequence) and a
                ``'tag'`` key (one tag per token, passed to the synonym
                generator).

        Returns:
            A list of ``(position, synonym)`` tuples, best-scoring first,
            truncated to ``self.attack_num(len(text['sentence']))`` entries.
            Empty when no position has any synonym candidate.
        """
        valid_pos, synonym_saliency = self.get_synonym_saliency(text)
        # Materialise the dict views once; both are zipped below and must
        # line up position-for-position.
        valid_pos = list(valid_pos)
        synonym_saliency = list(synonym_saliency)
        if not valid_pos:
            # Nothing can be substituted: skip the UNK pass entirely, which would
            # otherwise cost a model call and apply softmax to an empty list.
            return []

        unk_saliency = self.get_unk_saliency(text, valid_pos)

        # Score = normalised UNK saliency * best synonym's gold_prob drop.
        scored = [
            (pos, unk_s * delta_p, synonym)
            for pos, unk_s, (delta_p, synonym) in zip(valid_pos, unk_saliency, synonym_saliency)
        ]
        scored.sort(key=lambda entry: entry[1], reverse=True)

        budget = self.attack_num(len(text['sentence']))
        return [(pos, synonym) for pos, _, synonym in scored[:budget]]

    def get_synonym_saliency(self, text):
        """Find, for each position, the synonym that most reduces ``gold_prob``.

        All single-word synonym substitutions are predicted in one batch, with
        the unmodified ``text`` placed first so its ``gold_prob`` can serve as
        the baseline.

        Args:
            text: dict with ``'sentence'`` and ``'tag'`` sequences of equal use
                (they are zipped together).

        Returns:
            ``(positions, best)``: the keys and values views of a dict mapping
            position -> ``[delta_p, synonym]``, where
            ``delta_p = gold_prob(text) - gold_prob(text with synonym at position)``
            is maximised over that position's synonyms.  Positions without any
            synonym candidate are absent.
        """
        candidates = {'input': [text], 'pos': [], 'synonym': []}

        for i, (word, tag) in enumerate(zip(text['sentence'], text['tag'])):
            for w in self.synonym_candidate.candidate_set(word, tag):
                candidates['input'].append(self.subsitude(text, i, w))
                candidates['pos'].append(i)
                candidates['synonym'].append(w)
        outputs = self.predict_batch_data(candidates['input'])

        # NOTE(review): 'gold_prob' is presumably the model's probability of the
        # gold label for each input — confirm against predict_batch_data.
        loss = [o['gold_prob'] for o in outputs]
        prob_x, prob_candidate = loss[0], loss[1:]

        candidate_star = dict()
        for i, p, word in zip(candidates['pos'], prob_candidate, candidates['synonym']):
            delta_p = prob_x - p
            # Keep only the synonym with the largest drop at each position.
            if i not in candidate_star or candidate_star[i][0] < delta_p:
                candidate_star[i] = [delta_p, word]
        return candidate_star.keys(), candidate_star.values()

    def get_unk_saliency(self, text, valid_pos):
        """Measure how much masking each position with the UNK token hurts the model.

        Args:
            text: the original input dict.
            valid_pos: iterable of positions to mask, one at a time.

        Returns:
            ``softmax`` (from ``tools.utils``) of the per-position drops
            ``gold_prob(text) - gold_prob(text with UNK at position)``, in the
            same order as ``valid_pos``.
        """
        candidates = {'input': [text]}
        for i in valid_pos:
            candidates['input'].append(self.subsitude(text, i, self.unk_candidate))

        outputs = self.predict_batch_data(candidates['input'])

        loss = [o['gold_prob'] for o in outputs]
        prob_x, prob_candidate = loss[0], loss[1:]
        delta_p = [prob_x - p for p in prob_candidate]
        return softmax(delta_p)
